
import os
import argparse
import cv2
import random
import numpy as np
import imghdr
from copy import deepcopy

# import paddle
import torch

from PIL import Image, ImageDraw, ImageFont


def set_seed(seed):
    """Seed the global Python, NumPy and PyTorch random generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed`` (in that order).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def get_bio_label_maps(label_map_path):
    """Build BIO tag <-> id lookup tables from a label file.

    The file is read as UTF-8 with one label per line (surrounding
    whitespace is stripped). If no line equals ``"O"``, ``"O"`` is put in
    front. ``"O"`` is kept as a single tag; every other label ``X`` expands
    to ``"B-X"`` followed by ``"I-X"``. Ids are assigned in that order,
    starting from 0.

    Args:
        label_map_path: Path of the label file.

    Returns:
        Tuple ``(label2id_map, id2label_map)`` of dicts.
    """
    with open(label_map_path, "r", encoding='utf-8') as fin:
        names = [raw.strip() for raw in fin]

    if "O" not in names:
        names = ["O"] + names

    tags = []
    for name in names:
        if name == "O":
            tags.append("O")
        else:
            tags.extend(["B-" + name, "I-" + name])

    label2id_map = {}
    id2label_map = {}
    for idx, tag in enumerate(tags):
        label2id_map[tag] = idx
        id2label_map[idx] = tag
    return label2id_map, id2label_map


def get_image_file_list(img_file):
    """Return the sorted list of image files found at ``img_file``.

    ``img_file`` may be a single file or a directory. A file counts as an
    image when ``imghdr.what`` reports one of the accepted type names; the
    check is based on file content, not on the file extension. Directories
    are scanned one level deep only.

    Args:
        img_file: Path to an image file or to a directory of images.

    Returns:
        Sorted list of image file paths.

    Raises:
        Exception: If the path is ``None``, does not exist, or yields no
            image file.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}

    def _is_image(path):
        # Regular file whose sniffed type is one of the accepted names.
        return os.path.isfile(path) and imghdr.what(path) in img_end

    if _is_image(img_file):
        found = [img_file]
    elif os.path.isdir(img_file):
        candidates = [
            os.path.join(img_file, entry) for entry in os.listdir(img_file)
        ]
        found = [path for path in candidates if _is_image(path)]
    else:
        found = []

    if not found:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(found)


def draw_ser_results(image,
                     ocr_results,
                     font_path="../../doc/fonts/simfang.ttf",
                     font_size=18):
    """Overlay SER predictions (coloured boxes plus labels) on an image.

    Each entry of ``ocr_results`` is drawn as a filled box in the colour
    assigned to its ``"pred_id"`` with a ``"<pred>: <text>"`` caption. The
    annotated layer is then blended 50/50 with the input image.

    Args:
        image: ``PIL.Image`` or ``np.ndarray`` (converted with
            ``Image.fromarray``).
        ocr_results: Iterable of dicts providing the keys ``"pred_id"``,
            ``"pred"``, ``"text"`` and ``"bbox"`` (x1, y1, x2, y2).
        font_path: TrueType font file used for the captions.
        font_size: Font size handed to ``ImageFont.truetype``.

    Returns:
        The blended visualisation as an ``np.ndarray``.
    """
    # A private RandomState yields the same colour table as seeding the
    # global generator with 2021, but leaves the process-wide NumPy RNG
    # (e.g. a seed installed for reproducibility) untouched.
    rng = np.random.RandomState(2021)
    channels = (rng.permutation(range(255)),
                rng.permutation(range(255)),
                rng.permutation(range(255)))
    # Only ids 1..254 get a colour; entries with any other id are skipped.
    color_map = {
        idx: (int(channels[0][idx]), int(channels[1][idx]),
              int(channels[2][idx]))
        for idx in range(1, 255)
    }
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    for ocr_info in ocr_results:
        box_color = color_map.get(ocr_info["pred_id"])
        if box_color is None:
            continue
        text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])

        draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, box_color)

    # Blend so the original content stays visible under the filled boxes.
    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


def draw_box_txt(bbox, text, draw, font, font_size, color):
    """Draw one filled box and its caption on a PIL drawing context.

    Args:
        bbox: Sequence ``(x1, y1, x2, y2)`` of the box corners.
        text: Caption rendered just above the box.
        draw: Drawing object exposing ``rectangle`` and ``text``
            (e.g. ``PIL.ImageDraw.ImageDraw``).
        font: Font object used for measuring and rendering ``text``.
        font_size: Height reserved for the caption background.
        color: Fill colour of the box.
    """
    # draw ocr results outline
    top_left = (bbox[0], bbox[1])
    bottom_right = (bbox[2], bbox[3])
    draw.rectangle((top_left, bottom_right), fill=color)

    # draw ocr results
    # The caption sits directly above the box, clamped to the image top.
    start_y = max(0, top_left[1] - font_size)
    if hasattr(font, "getbbox"):
        # ``getsize`` was removed in Pillow 10; ``right - left`` of the
        # bounding box is the width ``getsize`` used to report.
        left, _, right, _ = font.getbbox(text)
        tw = right - left
    else:
        # Older Pillow releases without ``getbbox``.
        tw = font.getsize(text)[0]
    draw.rectangle(
        [(top_left[0] + 1, start_y),
         (top_left[0] + tw + 1, start_y + font_size)],
        fill=(0, 0, 255))
    draw.text(
        (top_left[0] + 1, start_y), text, fill=(255, 255, 255), font=font)


def draw_re_results(image,
                    result,
                    font_path="../../doc/fonts/simfang.ttf",
                    font_size=18):
    """Overlay relation-extraction results (linked box pairs) on an image.

    For every ``(head, tail)`` pair the head box is filled blue-ish
    ``(0, 0, 255)``, the tail box ``(255, 0, 0)``, and a ``(0, 255, 0)``
    line of width 5 joins the two box centres. The annotated layer is then
    blended 50/50 with the input image.

    Args:
        image: ``PIL.Image`` or ``np.ndarray`` (converted with
            ``Image.fromarray``).
        result: Iterable of ``(head, tail)`` pairs; each element is a dict
            with ``"bbox"`` (x1, y1, x2, y2) and ``"text"``.
        font_path: TrueType font file used for the captions.
        font_size: Font size handed to ``ImageFont.truetype``.

    Returns:
        The blended visualisation as an ``np.ndarray``.
    """
    # No randomness is used below, so the global NumPy RNG is deliberately
    # left alone (the former ``np.random.seed(0)`` only clobbered it).
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    for ocr_info_head, ocr_info_tail in result:
        draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
                     font_size, color_head)
        draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
                     font_size, color_tail)

        head_box = ocr_info_head['bbox']
        tail_box = ocr_info_tail['bbox']
        center_head = ((head_box[0] + head_box[2]) // 2,
                       (head_box[1] + head_box[3]) // 2)
        center_tail = ((tail_box[0] + tail_box[2]) // 2,
                       (tail_box[1] + tail_box[3]) // 2)

        draw.line([center_head, center_tail], fill=color_line, width=5)

    # Blend so the original content stays visible under the filled boxes.
    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


# pad sentences
def pad_sentences(tokenizer,
                  encoded_inputs,
                  max_seq_len=512,
                  pad_to_max_seq_len=True,
                  return_attention_mask=True,
                  return_token_type_ids=True,
                  return_overflowing_tokens=False,
                  return_special_tokens_mask=False):
    """Right-pad the token-level fields up to a multiple of ``max_seq_len``.

    The target length is ``(len(input_ids) // max_seq_len + 1) * max_seq_len``,
    i.e. always strictly longer than the input (an input that is already an
    exact multiple gains one full extra page). ``encoded_inputs`` is
    modified in place and returned.

    Args:
        tokenizer: Object providing ``padding_side``, ``pad_token_id`` and
            ``pad_token_type_id``.
        encoded_inputs: Dict with ``"input_ids"`` and ``"bbox"`` lists, plus
            ``"token_type_ids"`` / ``"special_tokens_mask"`` when the
            matching flag is set.
        max_seq_len: Page length the sequence is padded to a multiple of.
        pad_to_max_seq_len: When false, nothing is padded and only the
            attention mask is (optionally) rebuilt.
        return_attention_mask: Whether to write ``"attention_mask"``.
        return_token_type_ids: Whether to pad ``"token_type_ids"``.
        return_overflowing_tokens: Accepted but not used by this function.
        return_special_tokens_mask: Whether to pad ``"special_tokens_mask"``.

    Returns:
        The same ``encoded_inputs`` dict. Note that only
        ``padding_side == 'right'`` is implemented; for any other value the
        dict is returned without padding.
    """
    seq_len = len(encoded_inputs["input_ids"])
    # Padding with larger size, reshape is carried out
    padded_len = (seq_len // max_seq_len + 1) * max_seq_len

    should_pad = pad_to_max_seq_len and padded_len and seq_len < padded_len
    if not should_pad:
        if return_attention_mask:
            encoded_inputs["attention_mask"] = [1] * seq_len
        return encoded_inputs

    if tokenizer.padding_side != 'right':
        return encoded_inputs

    pad_count = padded_len - seq_len
    if return_attention_mask:
        encoded_inputs["attention_mask"] = [1] * seq_len + [0] * pad_count
    if return_token_type_ids:
        type_padding = [tokenizer.pad_token_type_id] * pad_count
        encoded_inputs["token_type_ids"] = (
            encoded_inputs["token_type_ids"] + type_padding)
    if return_special_tokens_mask:
        encoded_inputs["special_tokens_mask"] = (
            encoded_inputs["special_tokens_mask"] + [1] * pad_count)
    id_padding = [tokenizer.pad_token_id] * pad_count
    encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + id_padding
    box_padding = [[0, 0, 0, 0]] * pad_count
    encoded_inputs["bbox"] = encoded_inputs["bbox"] + box_padding
    return encoded_inputs


def split_page(encoded_inputs, max_seq_len=512):
    """Turn flat token-level lists into tensors of ``max_seq_len`` pages.

    Every value except ``"entities"`` is converted with ``torch.as_tensor``.
    Values with at most one dimension (input ids, attention mask, ...) are
    reshaped to ``[-1, max_seq_len]``; anything else (the boxes) is reshaped
    to ``[-1, max_seq_len, 4]``. ``"entities"`` is wrapped in a one-element
    list instead. The dict is modified in place and returned.

    Args:
        encoded_inputs: Dict of token-level sequences whose lengths are a
            multiple of ``max_seq_len``.
        max_seq_len: Number of tokens per page.

    Returns:
        The same ``encoded_inputs`` dict.
    """
    for name in encoded_inputs:
        value = encoded_inputs[name]
        if name == 'entities':
            encoded_inputs[name] = [value]
            continue
        tensor = torch.as_tensor(value)
        if tensor.ndim <= 1:  # for input_ids, att_mask and so on
            page_shape = [-1, max_seq_len]
        else:  # for bbox
            page_shape = [-1, max_seq_len, 4]
        encoded_inputs[name] = tensor.reshape(page_shape)
    return encoded_inputs


def preprocess(
        tokenizer,
        ori_img,
        ocr_info,
        img_size=(224, 224),
        pad_token_label_id=-100,
        max_seq_len=512,
        add_special_ids=False,
        return_attention_mask=True, ):
    """Build model inputs from an image and its OCR results.

    Every OCR segment is tokenized, its box is rescaled to a 0-1000 grid
    relative to the image size, and the token-level lists are padded and
    folded into pages of ``max_seq_len`` tokens.

    Args:
        tokenizer: Tokenizer whose ``encode`` returns a mapping with
            ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"``;
            it is also handed to ``pad_sentences``.
        ori_img: Image array indexed ``[height, width, channel]``.
        ocr_info: List of dicts with ``"bbox"`` (x1, y1, x2, y2 in pixels)
            and ``"text"``. A deep copy is used, so the caller's data is
            not modified.
        img_size: Size passed to ``cv2.resize``.
        pad_token_label_id: Accepted but not used by this function.
        max_seq_len: Number of tokens per page.
        add_special_ids: When false, the first and last token of every
            encoded segment are dropped.
        return_attention_mask: Forwarded to ``pad_sentences``.

    Returns:
        Dict with the paged tensors ``"input_ids"``, ``"token_type_ids"``,
        ``"bbox"``, ``"attention_mask"``, the ``"entities"`` list (wrapped
        in a list by ``split_page``), an ``"image"`` tensor expanded to one
        copy per page, and ``"segment_offset_id"`` (cumulative token count
        after each segment).
    """
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    # Resize, move channels first and convert to float32.
    img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    words_list = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []
    entities = []

    for info in ocr_info:
        # x1, y1, x2, y2
        bbox = info["bbox"]
        # Rescale pixel coordinates to a 0-1000 grid.
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        # for re
        entities.append({
            "start": len(input_ids_list),
            "end": len(input_ids_list) + len(encode_res["input_ids"]),
            "label": "O",
        })

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        # Every token of the segment carries the segment's box.
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
        words_list.append(text)
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
        "entities": entities
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    # The page length must match the one used for padding; relying on
    # split_page's default of 512 breaks the reshape for other values.
    encoded_inputs = split_page(encoded_inputs, max_seq_len=max_seq_len)

    # Number of pages the token sequence was folded into.
    fake_bs = encoded_inputs["input_ids"].shape[0]

    # One (shared-memory) copy of the image per page.
    encoded_inputs["image"] = torch.as_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id

    return encoded_inputs


def postprocess(attention_mask, preds, id2label_map):
    """Convert per-token logits into label names, dropping masked tokens.

    Args:
        attention_mask: 2-D indexable (``attention_mask[i][j]``); only
            positions equal to 1 are kept.
        preds: Logits indexed ``[batch, token, class]`` as a
            ``torch.Tensor`` or array-like.
        id2label_map: Mapping from class id to label name.

    Returns:
        One list of label names per batch row, in token order.
    """
    if isinstance(preds, torch.Tensor):
        # ``numpy()`` alone fails for CUDA tensors and for tensors that
        # still track gradients, so detach and move to host memory first.
        preds = preds.detach().cpu().numpy()
    label_ids = np.argmax(preds, axis=2)

    # keep batch info
    preds_list = []
    for i, row in enumerate(label_ids):
        preds_list.append([
            id2label_map[int(label_id)] for j, label_id in enumerate(row)
            if attention_mask[i][j] == 1
        ])

    return preds_list


def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
                                   label2id_map_for_draw):
    """Attach one majority-vote label to every OCR segment.

    The per-page predictions are flattened, cut into segments using the
    cumulative offsets in ``segment_offset_id``, and the most frequent label
    id of each segment (lowest id on ties, 0 for an empty segment) is stored
    on the matching ``ocr_info`` entry as ``"pred_id"`` together with its
    display name ``"pred"`` (the label without a ``B-`` / ``I-`` prefix).

    Args:
        ocr_info: List of dicts, one per segment; updated in place.
        segment_offset_id: Cumulative token count after each segment.
        preds_list: List of per-page label-name lists; all pages must come
            from the same image.
        label2id_map_for_draw: Mapping from label name to id.

    Returns:
        The same ``ocr_info`` list.
    """
    # must ensure the preds_list is generated from the same image
    flat_preds = []
    for page_preds in preds_list:
        flat_preds.extend(page_preds)

    id2label_map = {}
    for tag, tag_id in label2id_map_for_draw.items():
        has_bio_prefix = tag.startswith("B-") or tag.startswith("I-")
        id2label_map[tag_id] = tag[2:] if has_bio_prefix else tag

    seg_start = 0
    for idx, seg_end in enumerate(segment_offset_id):
        seg_ids = [
            label2id_map_for_draw[tag] for tag in flat_preds[seg_start:seg_end]
        ]
        if seg_ids:
            pred_id = int(np.argmax(np.bincount(seg_ids)))
        else:
            pred_id = 0
        ocr_info[idx]["pred_id"] = pred_id
        ocr_info[idx]["pred"] = id2label_map[pred_id]
        seg_start = seg_end
    return ocr_info


def print_arguments(args, logger=None):
    """Print all parsed arguments, sorted by name.

    Args:
        args: Namespace-like object; its attributes are read with ``vars``.
        logger: Optional logger. When given, lines go to ``logger.info``;
            otherwise they are written with ``print``.
    """
    print_func = logger.info if logger is not None else print
    print_func('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print_func('%s: %s' % (arg, value))
    print_func('------------------------------------------------')


def parse_args():
    """Parse the command-line options of the SER/RE and OCR tools.

    Reads ``sys.argv`` through ``argparse``. ``--model_name_or_path`` and
    ``--output_dir`` are required; every other option has a default.

    Returns:
        The populated ``argparse.Namespace``.
    """

    def str2bool(v):
        # Only "true", "t" and "1" (case-insensitive) count as True.
        return v.lower() in ("true", "t", "1")

    # Third parent directory of this file; base of the bundled font and
    # dictionary defaults below.
    root_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    parser = argparse.ArgumentParser()
    # Required parameters
    # yapf: disable
    parser.add_argument("--model_name_or_path",
                        default=None, type=str, required=True,)
    parser.add_argument("--ser_model_type",
                        default='LayoutXLM', type=str)
    parser.add_argument("--re_model_name_or_path",
                        default=None, type=str, required=False,)
    parser.add_argument("--train_data_dir", default=None,
                        type=str, required=False,)
    parser.add_argument("--train_label_path", default=None,
                        type=str, required=False,)
    parser.add_argument("--eval_data_dir", default=None,
                        type=str, required=False,)
    parser.add_argument("--eval_label_path", default=None,
                        type=str, required=False,)
    parser.add_argument("--output_dir", default=None, type=str, required=True,)
    parser.add_argument("--max_seq_length", default=512, type=int,)
    parser.add_argument("--evaluate_during_training", action="store_true",)
    parser.add_argument("--num_workers", default=8, type=int,)
    parser.add_argument("--per_gpu_train_batch_size", default=8,
                        type=int, help="Batch size per GPU/CPU for training.",)
    parser.add_argument("--per_gpu_eval_batch_size", default=8,
                        type=int, help="Batch size per GPU/CPU for eval.",)
    parser.add_argument("--learning_rate", default=5e-5,
                        type=float, help="The initial learning rate for Adam.",)
    parser.add_argument("--weight_decay", default=0.0,
                        type=float, help="Weight decay if we apply some.",)
    parser.add_argument("--adam_epsilon", default=1e-8,
                        type=float, help="Epsilon for Adam optimizer.",)
    parser.add_argument("--max_grad_norm", default=1.0,
                        type=float, help="Max gradient norm.",)
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.",)
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.",)
    parser.add_argument("--eval_steps", type=int, default=10,
                        help="eval every X updates steps.",)
    parser.add_argument("--seed", type=int, default=2048,
                        help="random seed for initialization",)

    parser.add_argument("--rec_model_dir", default=None, type=str, )
    parser.add_argument("--det_model_dir", default=None, type=str, )
    parser.add_argument(
        "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
    parser.add_argument("--infer_imgs", default=None, type=str, required=False)
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--ocr_json_path", default=None,
                        type=str, required=False, help="ocr prediction results")

    # ocr
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    # parser.add_argument("--ir_optim", type=str2bool, default=True)
    # parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    # parser.add_argument("--use_fp16", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=500)

    # params for text detector
    # parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_path", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
    parser.add_argument("--max_batch_size", type=int, default=10)
    # ``type=bool`` would turn any non-empty string (even "False") into
    # True, so boolean options are parsed with str2bool.
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_path", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)

    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--drop_score", type=float, default=0.5)
    parser.add_argument("--limited_max_width", type=int, default=1280)
    parser.add_argument("--limited_min_width", type=int, default=16)

    parser.add_argument(
        "--vis_font_path", type=str,
        default=os.path.join(root_dir, 'doc/fonts/simfang.ttf'))
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default=os.path.join(root_dir, 'pytorchocr/utils/ppocr_keys_v1.txt'))

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_path", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # NOTE(review): ``type=list`` splits a command-line value into single
    # characters; kept as-is so existing invocations behave the same.
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_path", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # params .yaml
    parser.add_argument("--det_yaml_path", type=str, default=None)
    parser.add_argument("--rec_yaml_path", type=str, default=None)
    parser.add_argument("--cls_yaml_path", type=str, default=None)

    # yapf: enable
    args = parser.parse_args()
    return args
